package com.alibaba.alink.operator.common.statistics.basicstatistic;

import org.apache.flink.types.Row;

import com.alibaba.alink.common.linalg.DenseVector;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.common.linalg.Vector;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.common.utils.PrettyDisplayUtils;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Summary statistics of a collection of sparse vectors.
 *
 * <p>Per-dimension statistics are kept in {@link #cols}, keyed by dimension index. The vector-valued
 * accessors ({@link #sum()}, {@link #mean()}, {@link #variance()}, ...) return a {@link SparseVector}
 * whose entries are the dimensions present in {@link #cols}; {@link #numNonZero()} returns a
 * {@link DenseVector}.
 */
public class SparseVectorSummary extends BaseVectorSummary {

	private static final long serialVersionUID = 7656305964922586155L;

	/**
	 * Known vector size, or -1 while unknown. {@link #vectorSize()} raises it so that it always
	 * covers the largest dimension index stored in {@link #cols}.
	 */
	int colNum = -1;

	/**
	 * Per-dimension statistics, keyed by dimension index.
	 */
	Map <Integer, VectorStatCol> cols = new HashMap <>();

	/**
	 * Column headers of the table rendered by {@link #toString()}.
	 */
	private static final String[] OUT_COL_NAMES = new String[] {"count",
		"sum", "mean", "variance", "stdDev", "min", "max", "normL1", "normL2"};

	/**
	 * Package-private constructor: instances are meant to be produced by the summary computation
	 * rather than created directly by user code.
	 */
	SparseVectorSummary() {
	}

	/**
	 * Renders one table row per dimension (row id = dimension index) with the statistics listed in
	 * {@link #OUT_COL_NAMES}.
	 */
	@Override
	public String toString() {
		int vectorSize = this.vectorSize();
		String[] colNames = new String[vectorSize];
		for (int i = 0; i < vectorSize; i++) {
			colNames[i] = String.valueOf(i);
		}

		Object[][] data = new Object[vectorSize][OUT_COL_NAMES.length];
		for (int i = 0; i < vectorSize; i++) {
			data[i][0] = count;
			data[i][1] = sum(i);
			data[i][2] = mean(i);
			data[i][3] = variance(i);
			data[i][4] = standardDeviation(i);
			data[i][5] = min(i);
			data[i][6] = max(i);
			data[i][7] = normL1(i);
			data[i][8] = normL2(i);
		}

		return "SparseVectorSummary:" +
			"\n" +
			PrettyDisplayUtils.displayTable(data, vectorSize, OUT_COL_NAMES.length, colNames, OUT_COL_NAMES, "id", 100, 100);
	}

	/**
	 * Returns the vector size: the larger of the current {@link #colNum} and (largest dimension
	 * index in {@link #cols}) + 1.
	 *
	 * <p>Side effect: the result is written back to {@link #colNum}.
	 */
	@Override
	public int vectorSize() {
		int maxIndex = -1;
		// Iterate the key set directly; no need to copy it into a boxed array first.
		for (int index : cols.keySet()) {
			if (index > maxIndex) {
				maxIndex = index;
			}
		}
		colNum = Math.max(colNum, maxIndex + 1);
		return colNum;
	}

	/**
	 * Sum of each dimension.
	 *
	 * @return sparse vector holding the sum of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector sum() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			data[i] = cols.get(indices[i]).sum;
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * Mean of each dimension, computed by {@code VectorStatCol.mean(count)}.
	 *
	 * @return sparse vector holding the mean of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector mean() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			data[i] = cols.get(indices[i]).mean(count);
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * Variance of each dimension, computed by {@code VectorStatCol.variance(count)}.
	 *
	 * @return sparse vector holding the variance of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector variance() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			data[i] = cols.get(indices[i]).variance(count);
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * Standard deviation of each dimension, computed by {@code VectorStatCol.standardDeviation(count)}.
	 *
	 * @return sparse vector holding the standard deviation of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector standardDeviation() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			data[i] = cols.get(indices[i]).standardDeviation(count);
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * Minimum of each dimension.
	 *
	 * <p>For a dimension with at least one non-zero value the result is the smaller of the recorded
	 * minimum and 0; otherwise it is 0.
	 *
	 * <p>NOTE(review): the clamp against 0 is applied even when every vector might carry an explicit
	 * value for the dimension — confirm this is the intended semantics.
	 *
	 * @return sparse vector holding the minimum of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector min() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			VectorStatCol statCol = cols.get(indices[i]);
			data[i] = 0;
			if (statCol.numNonZero > 0) {
				data[i] = Math.min(statCol.min, 0);
			}
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * Maximum of each dimension.
	 *
	 * <p>For a dimension with at least one non-zero value the result is the larger of the recorded
	 * maximum and 0; otherwise it is 0.
	 *
	 * <p>NOTE(review): the clamp against 0 is applied even when every vector might carry an explicit
	 * value for the dimension — confirm this is the intended semantics.
	 *
	 * @return sparse vector holding the maximum of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector max() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			VectorStatCol statCol = cols.get(indices[i]);
			data[i] = 0;
			if (statCol.numNonZero > 0) {
				data[i] = Math.max(statCol.max, 0);
			}
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * L1 norm of each dimension.
	 *
	 * @return sparse vector holding the L1 norm of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector normL1() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			data[i] = cols.get(indices[i]).normL1;
		}

		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * L2 norm of each dimension, i.e. the square root of the recorded square sum.
	 *
	 * @return sparse vector holding the L2 norm of every dimension present in {@link #cols}.
	 */
	@Override
	public Vector normL2() {
		int[] indices = getIndices();
		double[] data = new double[indices.length];
		for (int i = 0; i < data.length; i++) {
			data[i] = Math.sqrt(cols.get(indices[i]).squareSum);
		}
		return new SparseVector(vectorSize(), indices, data);
	}

	/**
	 * Number of non-zero values of each dimension.
	 *
	 * @return dense vector of length {@link #vectorSize()} whose i-th entry is the number of non-zero
	 * values in dimension i; dimensions absent from {@link #cols} are 0.
	 */
	public Vector numNonZero() {
		int size = vectorSize();
		double[] counts = new double[size];

		for (Map.Entry <Integer, VectorStatCol> entry : cols.entrySet()) {
			counts[entry.getKey()] = entry.getValue().numNonZero;
		}

		return new DenseVector(counts);
	}

	/**
	 * Number of non-zero values in a single dimension.
	 *
	 * <p>Looks the dimension up directly instead of materializing the whole dense vector.
	 *
	 * @param idx dimension index, in {@code [0, vectorSize())}.
	 * @return number of non-zero values in dimension {@code idx}; 0 if the dimension has no statistics.
	 * @throws ArrayIndexOutOfBoundsException if {@code idx} is outside {@code [0, vectorSize())}.
	 */
	public double numNonZero(int idx) {
		// vectorSize() also refreshes colNum, exactly as the dense-vector path does.
		int size = vectorSize();
		if (idx < 0 || idx >= size) {
			throw new ArrayIndexOutOfBoundsException(
				"Index " + idx + " out of bounds for vector size " + size);
		}
		VectorStatCol statCol = cols.get(idx);
		return statCol == null ? 0.0 : statCol.numNonZero;
	}

	/**
	 * Returns the dimension indices present in {@link #cols}, in the map's key iteration order.
	 */
	private int[] getIndices() {
		int[] out = new int[cols.size()];
		int pos = 0;
		for (int index : cols.keySet()) {
			out[pos++] = index;
		}
		return out;
	}

}
